import torch

t1 = torch.rand(2, 1, 5)
print(t1)

# 将t1广播到指定的shape，注意只能将dim==1的位置进行广播，或者将最左侧的位置增加维度。
# 简单说，就是只能对dim==1或左侧空dim进行广播
t2 = torch.broadcast_to(t1, [2, 3, 5])
print(t2)

t3 = torch.broadcast_to(t1, [4, 2, 3, 5])
print(t3)

# 加减法支持广播
a0 = torch.rand(2, 3)
a1 = torch.rand(2, 1)
a2 = a0 - a1
print(f'a0={a0}')
print(f'a1={a1}')
print(f'a2={a2}')

# 点乘，支持广播
m0 = torch.rand(2, 3)
m1 = torch.rand(2, 1)
m2 = m0 * m1
print(f'm0={m0}')
print(f'm1={m1}')
print(f'm2={m2}')

# elewise 除法，支持广播
d0 = torch.rand(2, 3)
d1 = torch.rand(2, 1)
d2 = d0 / d1
print(f'd0={d0}')
print(f'd1={d1}')
print(f'd2={d2}')

